Add token sharding functions for context parallel #26058
base: main
Conversation
Code Review
This pull request introduces utility functions for token sharding to enable context parallelism, which is a significant feature for improving performance. The changes correctly add the context_parallel_size configuration and update the distributed state management accordingly. The core logic for sharding is well-encapsulated in the new vllm/v1/attention/backends/cp_utils.py file. However, I've identified a couple of issues in the new test file, tests/v1/attention/test_context_parallel_attention.py, including a critical bug in an assertion that needs to be fixed for the tests to be valid.
```python
assert num_comp_local == [
    num_computed_tokens[0][-1] // 2, [num_computed_tokens[1][-1] // 2]
]
```
There appears to be a bug in this assertion. The expected value for num_comp_local should be a list of integers, but the expression [num_computed_tokens[1][-1] // 2] creates a list as the second element, resulting in [5, [4]]. The actual value of num_comp_local is [5, 4], which will cause this assertion to fail. For clarity and correctness, it's better to assert against the hardcoded expected value.
```python
assert num_comp_local == [5, 4]
```
```python
def make_cached_request_state(id: int, prefill_len: int, decode_len: int,
                              num_computed_tokens: list[int]):
    assert prefill_len + decode_len == sum(num_computed_tokens)
```
The assertion in this helper function is incorrect. num_computed_tokens is a list of cumulative token counts, so sum(num_computed_tokens) does not represent the total number of tokens. The total number of tokens is the last element of the list. The assertion should be assert prefill_len + decode_len == num_computed_tokens[-1].
```diff
- assert prefill_len + decode_len == sum(num_computed_tokens)
+ assert prefill_len + decode_len == num_computed_tokens[-1]
```
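To illustrate with made-up numbers: if prefill produced 5 tokens and decode produced 4 more, a cumulative list would be `[5, 9]`, whose sum (14) is not the total (9):

```python
# Illustrative values only, not taken from the test in this PR.
prefill_len, decode_len = 5, 4
num_computed_tokens = [5, 9]  # cumulative counts after each step
assert prefill_len + decode_len == num_computed_tokens[-1]   # 9 == 9
assert prefill_len + decode_len != sum(num_computed_tokens)  # 9 != 14
```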
This pull request has merge conflicts that must be resolved before it can be merged.
Add utility functions to enable load-balanced token sharding for context parallelism.
Purpose
Causal attention imposes a varying computational load for each token, as shown in the following figure. To ensure an even workload distribution, tokens should be partitioned across different context parallelism (CP) ranks. Specifically, the sequence is divided into 2 × cp_world_size chunks. Each CP rank i is assigned both the i-th chunk and the (2 × cp_world_size - i - 1)-th chunk. This approach helps balance the compute load among all CP ranks.

When tokens are distributed across context parallel (CP) ranks, gaps may appear in the block table. After compaction, tokens that are stored physically next to each other may not be logically consecutive. This is acceptable for CP because we only need to maintain the correct relative order of tokens for mapping purposes, rather than tracking their absolute positions in the block table.

Test Plan
End-to-end tests will be added in follow-up PRs.
Test Result
Essential Elements of an Effective PR Description Checklist
(Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.